/*
 * Copyright (c) 2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dyn_services_card_channel_inner.h"

#include <stddef.h>
#include <string.h>

#include <securec.h>
#include <tee_defines.h>
#include <tee_internal_se_api.h>

#include "dyn_services_card_channel_scp.h"
#include "logger.h"
#include "se_base_services_defines.h"

/*
 * Packs two status values into one integer: the low 16 bits of teeRet are kept
 * in the low half, and serviceRet is shifted left by 16 bits and OR-ed on top.
 */
#define WRAP_RET(teeRet, serviceRet) (((teeRet)&0xFFFF) | ((serviceRet) << 16U))

/*
 * Searches the readers exposed by an SE service for one whose name equals readerName.
 *
 * service:    handle passed to TEE_SEServiceGetReaders.
 * readerName: NUL-terminated name compared with strcmp against each reader's name.
 *
 * Returns the first matching reader handle, or NULL when an argument is NULL,
 * the reader list cannot be obtained (or is empty), or no name matches.
 */
TEE_SEReaderHandle FindSeReaderByName(TEE_SEServiceHandle service, const char *readerName)
{
    if (service == NULL || readerName == NULL) {
        LOG_ERROR("input is nullptr");
        return NULL;
    }

    TEE_SEReaderHandle candidates[MAX_SE_READER_COUNT] = {0};
    uint32_t candidateCount = MAX_SE_READER_COUNT;
    TEE_Result status = TEE_SEServiceGetReaders(service, candidates, &candidateCount);
    if (status != TEE_SUCCESS || candidateCount == 0) {
        LOG_ERROR("service get reader error");
        return NULL;
    }

    for (uint32_t idx = 0; idx < candidateCount; idx++) {
        // One spare, zero-filled byte keeps the name terminated for strcmp.
        char name[MAX_SE_READER_NAME_SIZE + 1] = {0};
        uint32_t nameLen = MAX_SE_READER_NAME_SIZE;

        // Readers whose name cannot be fetched are skipped rather than treated as fatal.
        if (TEE_SEReaderGetName(candidates[idx], name, &nameLen) == TEE_SUCCESS && strcmp(name, readerName) == 0) {
            return candidates[idx];
        }
    }
    return NULL;
}

/*
 * Allocates a zero-initialized SecureElementContext and stores the reader name
 * and the application identifier (AID) in it.
 *
 * readerName: NUL-terminated reader name. It must not exceed
 *             MAX_SE_READER_NAME_SIZE characters and must leave room for a
 *             terminating NUL inside context->readerName.
 * identifier: provides the AID bytes (identifier->aid) and their length
 *             (identifier->aidLen); the copy is bounded by AID_LENGTH_MAX.
 *
 * Returns the new context, or NULL if an argument is NULL, the allocation
 * fails, the name is too long, or a copy is rejected by memcpy_s.
 * DestroySecureElementContext frees a context created here.
 */
SecureElementContext *CreateSecureElementContext(const char *readerName, const AppIdentifier *identifier)
{
    if (readerName == NULL || identifier == NULL) {
        return NULL;
    }

    SecureElementContext *context = malloc(sizeof(*context));
    if (context == NULL) {
        LOG_ERROR("malloc error");
        return NULL;
    }
    (void)memset_s(context, sizeof(*context), 0, sizeof(*context));

    // The stored name is later used as a C string (strcmp, "%s"), so the copy must
    // always leave at least one zero byte of the cleared field in place as terminator.
    size_t nameLen = strlen(readerName);
    if (nameLen > MAX_SE_READER_NAME_SIZE || nameLen >= sizeof(context->readerName)) {
        LOG_ERROR("readerName is too long");
        free(context);
        return NULL;
    }

    if (memcpy_s(context->readerName, sizeof(context->readerName), readerName, nameLen) != EOK) {
        LOG_ERROR("memcpy_s readerName error");
        free(context);
        return NULL;
    }

    if (memcpy_s(context->aid, AID_LENGTH_MAX, identifier->aid, identifier->aidLen) != EOK) {
        LOG_ERROR("memcpy_s identifier error");
        free(context);
        return NULL;
    }

    context->aidLen = identifier->aidLen;

    return context;
}

/*
 * Closes whatever the context still has open, wipes its memory and frees it.
 * A NULL context is ignored.
 */
void DestroySecureElementContext(SecureElementContext *context)
{
    if (context != NULL) {
        // Close first: the handles live inside the structure that is about to be wiped.
        (void)SeChannelClose(context);
        (void)memset_s(context, sizeof(*context), 0, sizeof(*context));
        free(context);
    }
}

/*
 * Opens the full path to the applet described by the context:
 * SE service -> reader found by context->readerName -> session -> logical
 * channel selected with context->aid, then checks the select response.
 *
 * Returns SUCCESS when every step succeeds and the select response passes
 * IsApduResponseSuccess. Returns INVALID_PARA_NULL_PTR for a NULL context.
 * On any other failure everything opened so far is released through
 * SeChannelClose and WRAP_RET(ret, CHN_OPEN_ERR) is returned, where ret is the
 * TEE result of the failing step. If only the status-word check fails, ret is
 * still TEE_SUCCESS, so the TEE part of the wrapped code is zero.
 */
ResultCode SeChannelOpen(SecureElementContext *context)
{
    if (context == NULL) {
        LOG_ERROR("input is nullptr");
        return INVALID_PARA_NULL_PTR;
    }

    bool success = false;
    TEE_Result ret = TEE_FAIL;
    // Single-pass loop: "break" jumps to the shared cleanup below.
    do {
        ret = TEE_SEServiceOpen(&context->service);
        if (ret != TEE_SUCCESS) {
            LOG_ERROR("service open error, ret is 0x%x", ret);
            break;
        }
        const char *readerName = context->readerName;
        context->reader = FindSeReaderByName(context->service, readerName);
        if (context->reader == NULL) {
            // The lookup returns no TEE_Result, so report a "not found" code.
            ret = TEE_ERROR_ITEM_NOT_FOUND;
            LOG_ERROR("service get reader %s failed", readerName);
            break;
        }

        ret = TEE_SEReaderOpenSession(context->reader, &context->session);
        if (ret != TEE_SUCCESS) {
            LOG_ERROR("service open session error for card %s, ret is 0x%x", readerName, ret);
            break;
        }

        TEE_SEAID seAid = {.buffer = context->aid, .bufferLen = context->aidLen};

        ret = TEE_SESessionOpenLogicalChannel(context->session, &seAid, &context->channel);
        if (ret != TEE_SUCCESS) {
            LOG_ERROR("service open channel error for card %s, ret is 0x%x", readerName, ret);
            break;
        }

        uint8_t response[RES_APDU_MAX_LEN];
        uint32_t length = RES_APDU_MAX_LEN;
        ret = TEE_SEChannelGetSelectResponse(context->channel, response, &length);
        if (ret != TEE_SUCCESS) {
            // Distinct message so this step is not confused with the channel open above.
            LOG_ERROR("service get select response error for card %s, ret is 0x%x", readerName, ret);
            break;
        }

        if (!IsApduResponseSuccess(response, length)) {
            LOG_ERROR("IsApduResponseSuccess failed");
            break;
        }
        success = true;
    } while (0);

    if (!success) {
        // Best-effort teardown; the open error is what gets reported.
        (void)SeChannelClose(context);
        return WRAP_RET(ret, CHN_OPEN_ERR);
    }

    return SUCCESS;
}

/*
 * Releases every handle held by the context, in the order channel, session,
 * reader sessions, service. Each handle is reset to NULL after it is closed,
 * so a repeated call finds nothing left to close.
 *
 * Returns INVALID_PARA_NULL_PTR for a NULL context, SUCCESS otherwise.
 */
ResultCode SeChannelClose(SecureElementContext *context)
{
    if (context == NULL) {
        return INVALID_PARA_NULL_PTR;
    }

    // Logical channel.
    if (context->channel != NULL) {
        TEE_SEChannelClose(context->channel);
    }
    context->channel = NULL;

    // Session that carried the channel.
    if (context->session != NULL) {
        TEE_SESessionClose(context->session);
    }
    context->session = NULL;

    // Any sessions still attached to the reader.
    if (context->reader != NULL) {
        TEE_SEReaderCloseSessions(context->reader);
    }
    context->reader = NULL;

    // Finally the service itself.
    if (context->service != NULL) {
        TEE_SEServiceClose(context->service);
    }
    context->service = NULL;

    return SUCCESS;
}

/*
 * Establishes a secure channel on the logical channel already opened in the context.
 *
 * context:    must be non-NULL and hold an open channel (context->channel != NULL).
 * key/keyLen/keyVersion/keyId: forwarded unchanged to CreateScpParams, which
 *             builds the secure-channel parameters.
 *
 * Returns SUCCESS when TEE_SESecureChannelOpen succeeds,
 * INVALID_PARA_NULL_PTR when the context or its channel is NULL,
 * INVALID_PARA_ERR_VALUE when CreateScpParams returns NULL, and
 * WRAP_RET(teeResult, CHN_OPEN_ERR) when the TEE call fails. In that last
 * case the whole context is also closed through SeChannelClose.
 */
ResultCode SeSecureChannelOpen(SecureElementContext *context, const uint8_t *key, uint32_t keyLen, uint8_t keyVersion,
    uint8_t keyId)
{
    // Same guard as the other channel operations: never hand a NULL channel to the TEE API.
    if (context == NULL || context->channel == NULL) {
        LOG_ERROR("input context error");
        return INVALID_PARA_NULL_PTR;
    }

    TEE_SC_Params *params = CreateScpParams(key, keyLen, keyVersion, keyId);
    if (params == NULL) {
        return INVALID_PARA_ERR_VALUE;
    }

    TEE_Result open = TEE_SESecureChannelOpen(context->channel, params);
    // The parameters are only needed for the call itself, whatever its outcome.
    DeleteScpParams(params);
    if (open != TEE_SUCCESS) {
        LOG_ERROR("TEE_SESecureChannelOpen error = 0x%x", open);
        (void)SeChannelClose(context);
        return WRAP_RET(open, CHN_OPEN_ERR);
    }

    LOG_INFO("TEE_SESecureChannelOpen success");
    return SUCCESS;
}

/*
 * Placeholder for closing the secure channel.
 *
 * Returns INVALID_PARA_NULL_PTR for a NULL context and SUCCESS otherwise.
 */
ResultCode SeSecureChannelClose(SecureElementContext *context)
{
    // TEE_SESecureChannelClose currently is not exported, so even with an open
    // channel there is nothing to do here beyond validating the argument.
    return (context == NULL) ? INVALID_PARA_NULL_PTR : SUCCESS;
}

/*
 * Sends one command over the open channel and receives the response.
 *
 * context:     must be non-NULL with an open channel.
 * command:     command bytes; commandLen must be non-zero and fit into
 *              MAX_INPUT_COMMAND_LEN (the bytes are staged in a local buffer).
 * response:    output buffer handed to TEE_SEChannelTransmit.
 * responseLen: in/out length passed to TEE_SEChannelTransmit; the pointed-to
 *              value must be non-zero on entry.
 *
 * Returns SUCCESS, INVALID_PARA_NULL_PTR, INVALID_PARA_ERR_SIZE, MEM_COPY_ERR,
 * or WRAP_RET(teeResult, CMD_APDU_TRANS_ERR) when the TEE call fails.
 */
ResultCode SeChannelTransmit(SecureElementContext *context, const uint8_t *command, uint32_t commandLen,
    uint8_t *response, uint32_t *responseLen)
{
    if (context == NULL || context->channel == NULL) {
        LOG_ERROR("input context error");
        return INVALID_PARA_NULL_PTR;
    }

    if (command == NULL || response == NULL || responseLen == NULL) {
        LOG_ERROR("input command or response ptr error");
        return INVALID_PARA_NULL_PTR;
    }

    if (commandLen == 0 || *responseLen == 0) {
        LOG_ERROR("input commandLen or responseLen error");
        return INVALID_PARA_ERR_SIZE;
    }

    // The TEE call takes a non-const buffer, so the command is staged in a
    // writable local copy instead of casting away const.
    uint8_t staged[MAX_INPUT_COMMAND_LEN] = {0};
    if (memcpy_s(staged, MAX_INPUT_COMMAND_LEN, command, commandLen) != EOK) {
        LOG_ERROR("memcpy command error, commandLen is %u", commandLen);
        return MEM_COPY_ERR;
    }

    TEE_Result status = TEE_SEChannelTransmit(context->channel, staged, commandLen, response, responseLen);
    if (status == TEE_SUCCESS) {
        return SUCCESS;
    }

    LOG_ERROR("TEE_SEChannelTransmit err = 0x%x", status);
    return WRAP_RET(status, CMD_APDU_TRANS_ERR);
}

/*
 * Fetches the select response of the open channel through
 * TEE_SEChannelGetSelectResponse.
 *
 * context:     must be non-NULL with an open channel.
 * response:    output buffer.
 * responseLen: in/out length; on entry it must be at least MIN_APDU_RESP_SIZE.
 *
 * Returns SUCCESS, INVALID_PARA_NULL_PTR, INVALID_PARA_ERR_SIZE, or
 * CHN_OPEN_ERR when the TEE call fails.
 */
ResultCode SeChannelChannelGetOpenResponse(SecureElementContext *context, uint8_t *response, uint32_t *responseLen)
{
    if (context == NULL || context->channel == NULL) {
        LOG_ERROR("input context error");
        return INVALID_PARA_NULL_PTR;
    }

    if (response == NULL || responseLen == NULL) {
        LOG_ERROR("input command or response ptr error");
        return INVALID_PARA_NULL_PTR;
    }

    if (*responseLen < MIN_APDU_RESP_SIZE) {
        LOG_ERROR("input response length error");
        return INVALID_PARA_ERR_SIZE;
    }

    TEE_Result status = TEE_SEChannelGetSelectResponse(context->channel, response, responseLen);
    if (status == TEE_SUCCESS) {
        return SUCCESS;
    }

    LOG_ERROR("TEE_SEChannelGetSelectResponse err = %u", status);
    return CHN_OPEN_ERR;
}

/*
 * Checks whether a response ends with the status word SW_NO_ERROR.
 *
 * response: response bytes; the last two bytes are read as SW1 and SW2.
 * length:   number of valid bytes; must lie within
 *           [RES_APDU_MIN_LEN, RES_APDU_MAX_LEN].
 *
 * Returns true only when response is non-NULL, length is in range and the
 * 16-bit value (SW1 << 8) | SW2 equals SW_NO_ERROR.
 */
bool IsApduResponseSuccess(const uint8_t *response, uint32_t length)
{
    bool lengthInRange = (length >= RES_APDU_MIN_LEN) && (length <= RES_APDU_MAX_LEN);
    if (response == NULL || !lengthInRange) {
        return false;
    }

    // The status word occupies the final two bytes, most significant byte first.
    const uint8_t *trailer = response + (length - 2);
    uint16_t statusWord = (uint16_t)(((uint16_t)trailer[0] << 8) | (uint16_t)trailer[1]);
    return (statusWord == SW_NO_ERROR);
}